/*
    Copyright 2006-2011 Patrik Jonsson, sunrise@familjenjonsson.org

    This file is part of Sunrise.

    Sunrise is free software; you can redistribute it and/or modify
    it under the terms of the GNU General Public License as published by
    the Free Software Foundation; either version 3 of the License, or
    (at your option) any later version.

    Sunrise is distributed in the hope that it will be useful,
    but WITHOUT ANY WARRANTY; without even the implied warranty of
    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
    GNU General Public License for more details.

    You should have received a copy of the GNU General Public License
    along with Sunrise.  If not, see <http://www.gnu.org/licenses/>.

*/

/// \file
/// Definition of the nbody_data_grid and related classes. \ingroup makegrid

// $Id$

#ifndef __sphgrid__
#define __sphgrid__

#include "grid.h"

#include "blitz/array.h"
#include "boost/shared_ptr.hpp"
#include "boost/thread/mutex.hpp"
#include <vector>
#include <algorithm> 
#include "sphparticle.h"
#include "counter.h"

namespace mcrx {
  class nbody_data;
  class nbody_data_cell;
  class refinement_accuracy_data;
  class tolerance_checker;
  class nbody_data_grid;
  class check_particle_overlap_p;
  template <typename> class particle;
  class Snapshot; // forward decl
}

// Forward declarations of the CCfits classes referenced in the
// interfaces below.
namespace CCfits {
  class FITS;
  class HDU;
  class ExtHDU; // used by nbody_data_grid's FITS constructor and load_data
}

// Declared at global scope so that the classes below can befriend
// the program's ::main.
int main (int, char**);

/** Contains quantities about the particles output by sfrhist.  These
 quantities are then calculated for each grid cell when the adaptive
 grid is constructed.  The weighted members (age_m, age_l, gas_temp_m,
 gas_teff_m) hold weighted sums rather than averages; for example,
 gas_temp() divides gas_temp_m by the gas mass.  \ingroup makegrid */
class mcrx::nbody_data {
  friend class nbody_data_grid;
  friend int ::main(int, char**);
public:
  typedef mcrx::T_float T_float;
protected:
  T_float m_g_;   ///< Gas mass.
  vec3d p_g_;     ///< Gas momentum
  T_float L_bol_; ///< Bolometric luminosity.
  T_float SFR;    ///< Star-formation rate.
  T_float m_metals_; ///< Mass in metals (in the gas).
  T_float m_s_; ///< Stellar mass.
  vec3d p_s_;  ///< Stellar momentum
  T_float m_m_s;  ///< Mass in metals (in the stars).
  T_float age_m; ///< Mass-weighted stellar age.
  T_float age_l; ///< Luminosity-weighted stellar age.
  T_float gas_temp_m; ///< Mass-weighted gas temperature.
  T_float gas_teff_m; ///< Mass-weighted gas effective temperature.

  array_1 L_lambda; ///< Spectral energy distribution of stellar emission.

public:
  /// Default constructor: every scalar and vector quantity is zero and
  /// L_lambda is an empty array.
  nbody_data (): m_g_ (0), p_g_(0,0,0), L_bol_ (0), SFR (0), m_metals_ (0), m_s_(0),
		 p_s_(0,0,0),
		 m_m_s(0),age_m(0),age_l(0),gas_temp_m(0),gas_teff_m(0),
		 L_lambda (make_thread_local_copy(array_1 ())) {};
  /// Constructs from explicit values.  L_lambda is initialized with
  /// make_thread_local_copy(Ll); Ll defaults to an empty array.
  nbody_data (T_float m_gas, const vec3d& p_gas, T_float Lb, T_float s, T_float met,
	      T_float m_star, const vec3d& p_star, T_float starmet, T_float agem, 
	      T_float agel,
	      T_float mtemp, T_float mteff,
	      const array_1& Ll = array_1 ()):
    m_g_ (m_gas), p_g_(p_gas), L_bol_ (Lb), SFR (s), m_metals_ (met),
    m_s_(m_star), p_s_(p_star), m_m_s(starmet),age_m(agem),age_l(agel),
    gas_temp_m(mtemp), gas_teff_m(mteff), 
    L_lambda (make_thread_local_copy(Ll)) {};
  /** Copy constructor.  It does NOT have reference semantics:
      L_lambda is allocated with d.L_lambda's size and the values are
      then copied element-wise, so the copy owns its own spectrum. */
  nbody_data (const nbody_data& d): m_g_ (d.m_g_), p_g_(d.p_g_), L_bol_(d.L_bol_),
				    SFR (d.SFR), m_metals_ (d.m_metals_), 
				    m_s_(d.m_s_), p_s_(d.p_s_), m_m_s(d.m_m_s),
				    age_m(d.age_m), age_l(d.age_l),
				    gas_temp_m(d.gas_temp_m),
				    gas_teff_m(d.gas_teff_m),
				    // TODO: unlike the other constructors this does
				    // not go through make_thread_local_copy -- intended?
				    L_lambda (d.L_lambda.size()) {
    L_lambda = d.L_lambda;};

  ///@{ 
  /// \name Access operators
  T_float m_g() const { return m_g_;};
  T_float m_metals() const { return m_metals_;};
  const vec3d& p_g() const { return p_g_;};
  T_float L_bol() const { return L_bol_;};
  /// Mass-weighted gas temperature, i.e. gas_temp_m / m_g.  Not
  /// guarded: when m_g is zero this divides by zero.
  T_float gas_temp() const { return gas_temp_m/m_g_;};
  /// Mass-weighted gas effective temperature, i.e. gas_teff_m / m_g.
  /// Not guarded: when m_g is zero this divides by zero.
  T_float gas_teff() const { return gas_teff_m/m_g_;};
  ///@}

  ///@{
  /// \name Arithmetic operators
  /// These operators are used when adding quantities in particles together.
  
  /// Assignment operator, resizes L_lambda as necessary.
  nbody_data& operator= (const nbody_data&);
  /// Adds every member of rhs, including momenta and L_lambda, to this.
  nbody_data& operator+= (const nbody_data& rhs) {
    m_g_+= rhs.m_g_; p_g_+= rhs.p_g_; L_bol_+= rhs.L_bol_; SFR += rhs.SFR;
    m_metals_ += rhs.m_metals_; 
    m_s_ += rhs.m_s_; p_s_+= rhs.p_s_; m_m_s += rhs.m_m_s; age_m += rhs.age_m;
    age_l += rhs.age_l; 
    gas_temp_m += rhs.gas_temp_m; gas_teff_m += rhs.gas_teff_m;
    L_lambda+= rhs.L_lambda; return *this;}
  /** Multiplication operator does element-wise multiplication.  Note
      that this operator does NOT operate on the vector members,
      i.e. L_lambda and momenta, because that doesn't make sense
      given how the operator is used.  */
  nbody_data& operator*= (const nbody_data& rhs) {
    // Deliberately skips L_lambda, p_g_ and p_s_ (see above).
    m_g_*= rhs.m_g_; L_bol_*= rhs.L_bol_; SFR*= rhs.SFR;
    m_metals_*= rhs.m_metals_; 
    m_s_ *= rhs.m_s_; m_m_s *= rhs.m_m_s; age_m *= rhs.age_m;
    age_l *= rhs.age_l; 
    gas_temp_m *= rhs.gas_temp_m; gas_teff_m *= rhs.gas_teff_m; 
    return *this;} 
  /// See operator*= for an important note about how this works.
  nbody_data operator*(const nbody_data& rhs) const {
    return nbody_data (*this)*= rhs;}
  /// Scales every member, including momenta and L_lambda, by rhs.
  nbody_data& operator*= (T_float rhs) {
    m_g_*= rhs; p_g_*= rhs; L_bol_*= rhs; SFR*= rhs; m_metals_*= rhs;
    m_s_ *= rhs; p_s_ *= rhs; m_m_s *= rhs; age_m *= rhs;
    age_l *= rhs; gas_temp_m *= rhs; gas_teff_m *= rhs;
    L_lambda*= rhs; return *this;} 
  /// Returns a scaled copy; see operator*=(T_float).
  nbody_data operator*(T_float rhs) const {
    return nbody_data (*this)*= rhs;}
  /** A "floating multiply and add"-operator: this += d*f.  This could
      have been implemented using some fancy operator magic, but it is
      not worth it.  L_lambda is only accumulated when d has a
      non-empty spectrum. */
  void add_to (/// The object to be added to this.
	       const nbody_data& d,
	       /// The scalar to multiply d with before adding.
	       T_float f,
	       /** A volume, needed only if there were intensive
                   quantities, which currently there aren't; unused. */
	       T_float vol) {
    assert (d.m_g_ >= 0);
    assert (d.L_bol_ >= 0);
    assert (m_g_ >= 0);
    assert (L_bol_ >= 0);
    m_g_ += d.m_g_*f; p_g_ += d.p_g_*f; L_bol_+= d.L_bol_*f; SFR += d.SFR*f;
    m_metals_ += d.m_metals_*f;
    m_s_ += d.m_s_*f; p_s_ += d.p_s_*f; m_m_s += d.m_m_s*f; age_m += d.age_m*f;
    age_l += d.age_l*f; 
    gas_temp_m += d.gas_temp_m*f;
    gas_teff_m += d.gas_teff_m*f;
    if (d.L_lambda.size() > 0) L_lambda += d.L_lambda*f;}
  ///@}
};

/** This class is basically just nbody_data, but with the added
    (static) functions used when creating the grid refinements.  
    \ingroup makegrid */
class mcrx::nbody_data_cell: public nbody_data {
public:
  /// Default constructor: all quantities zero (see nbody_data()).
  nbody_data_cell() : nbody_data() {};
  /// Constructs from explicit values; forwards everything to nbody_data.
  nbody_data_cell (T_float m_gas, const vec3d& p_gas, T_float Lb, T_float s, T_float met,
		   T_float m_star, const vec3d& p_star, T_float starmet, 
		   T_float agem, T_float agel,
		   T_float mtemp, T_float mteff,
		   const array_1& Ll = array_1 ()):
    nbody_data (m_gas, p_gas, Lb, s, met, m_star, p_star, starmet, agem, agel, 
		mtemp, mteff, Ll) {};
  /** A "static downcast", converting an nbody_data object into
      nbody_data_cell by copying it.  This is perfectly safe since
      there are no extra data members in nbody_data_cell.  */
  nbody_data_cell (const nbody_data& d):
    nbody_data (d) {};

  /** Calculates the proper quantities in a cell which is the
      unification of a number of smaller cells.  Because all
      quantities are extensive, the unification cell corresponding to
      a group of cells is simply the sum of their quantities, which is
      returned unchanged.  If there were intensive quantities, they
      would have to be averaged instead (presumably using the int
      argument, the number of cells, which is currently unused).  */
  static nbody_data_cell unification (const nbody_data_cell& sum, int) {
    return sum;};
};


/** Keeps information about the grid refinement accuracy when creating
    the hierarchical grid.  This information is necessary for the
    calculation of the standard deviation of the quantities in the
    subgrids. \ingroup makegrid */
class mcrx::refinement_accuracy_data {
public:
  nbody_data sum; ///< Sum of quantities.
  nbody_data sumsq; ///< Sum of squares of quantities.
  int n; ///< Number of terms in the sum.
  bool all_leaves; ///< True if all cells are leaf cells.

  /** Constructor.  The nbody_data arguments are taken by const
      reference: nbody_data's copy constructor deep-copies L_lambda,
      so passing by value would copy each spectrum an extra time. */
  refinement_accuracy_data (const nbody_data& s, const nbody_data& ssq,
			    int nn, bool al):
    sum (s), sumsq (ssq), n (nn), all_leaves (al) {}; 

  /** Increment operator used to accumulate over cells.  sum, sumsq
      and n are added; all_leaves is AND'ed together.  */
  refinement_accuracy_data& operator+= (const refinement_accuracy_data& rhs) {
    sum += rhs.sum; sumsq += rhs.sumsq; n+= rhs.n; 
    all_leaves = all_leaves && rhs.all_leaves;
    return *this;};
};

/** This class is a predicate that returns whether a cell can be
    unified. It keeps tolerances and data necessary to test
    unification. */
class mcrx::tolerance_checker {
private:
  typedef adaptive_grid<nbody_data_cell>::T_cell T_cell;

  const nbody_data_cell tolerance_;
  const nbody_data_cell tolerance_absolute_;
  const T_float gas_metallicity_;
  const T_float max_metal_column_;
  /// Total bolometric luminosity; 0 until set_L_bol() is called.
  T_float L_bol_tot_;

public: 
  /** Constructor.  L_bol_tot_ starts at 0 so it is never read
      uninitialized; call set_L_bol() to set the real value. */
  tolerance_checker(const nbody_data_cell& tolerance, 
		    const nbody_data_cell& tol_absolute, 
		    T_float gm, T_float maxcol) :
    tolerance_(tolerance), tolerance_absolute_(tol_absolute),
    gas_metallicity_(gm), max_metal_column_(maxcol), L_bol_tot_(0) {};

  /// Sets the total bolometric luminosity.
  void set_L_bol (T_float l) {L_bol_tot_=l;};
  const nbody_data_cell& tolerance() const { return tolerance_;};
  const nbody_data_cell& tolerance_absolute() const {
    return tolerance_absolute_;};
  T_float max_metal_column() const { return max_metal_column_;};

  /** Tests if a collection of cells can be unified.  The function
      returns true if the collection of cells represented by sum and
      sumsq have a lower fractional deviation than tolerance or if the
      absolute standard deviation is smaller than tolerance_absolute
      for all quantities, and the column density of metals is smaller
      than the specified maximum. This is used for ensuring that the
      grid resolves a specific optical depth, which is important for
      the temperature calculations.  */
  bool unify_cell_p (const nbody_data_cell& sum,
		     const nbody_data_cell& sumsq,
		     int n, const T_cell& c) const;

  /** Tests if a cell should be refined because it has too large metal
      column density, based on the particles that are inside it. */
  bool refine_cell_p (const check_particle_overlap_p& cpo,
		      const T_cell& c) const;
  void print_tolerances() const;
};


/** A grid containing data from the nbody particles, used for creating
    the hierarchical grid. \ingroup makegrid */
class mcrx::nbody_data_grid:
  public adaptive_grid <nbody_data_cell>
{
public:
  typedef adaptive_grid <nbody_data_cell> T_base;
  friend int ::main(int, char**);
  typedef particle<nbody_data> T_particle;
private:
  ///@{
  /// \name Refinement variables
  /// These members control the behavior of the adaptive refinement.
  
  /** Factor determining when refinement is truncated based on the
      size of the particles in the cell.  If the cell size is smaller
      than size_factor times the radius of the smallest contained
      particle, no further refinement is done.  */
  T_float size_factor;
  /// Maximum refinement level.
  int max_level;

  /// Predicate object for checking whether cells can be unified.
  std::auto_ptr<tolerance_checker> tol_checker_;

  /// "Default" cell data, used if there are no particles in the cell.
  T_data data_zero;
  counter ctr;
  ///@}

  ///@{
  /// \name Units 
  /// The units are simply propagated from snapshot file to output file.
  std::string mass_unit;
  std::string time_unit;
  std::string length_unit; 
  std::string L_bol_unit;
  std::string SFR_unit;
  std::string temp_unit;
  std::string L_lambda_unit;
  ///@}
  
  /// Total grid quantities, summed over all cells
  nbody_data_cell total_quantities;

  ///@{
  /// \name Threading data structures
  /// These variables are used to manage the multithreaded refinement.

  mutable int n_threads; ///< Number of threads to use.
  /** Number of levels to refine before dividing the work among
      threads.  This is a load-balancing issue, see the
      makegrid-README.  */
  int work_chunk_levels;
  /// Mutex used to protect access to the vector of cells to process.
  mutable boost::mutex cell_stack_mutex;
  /// The vector of cells that needs to be processed.
  std::vector<std::pair<T_cell*, int> > cell_stack;
  /// The particles that the grid will be based on.
  std::vector<T_particle*> particle_list_;
  class thread_start; // function object used to start threads
  /// Keeps statistics about how many cells are at which refinement level.
  std::vector<int> creation_stats;
  void pop_cell_and_refine ();
  ///@}

  ///@{
  /** \name Load-balancing threading data structures 
      These variables are used to manage the multithreaded creation of
      a list of approximately load-balanced cells to work on.  The
      cell_stack_mutex is also used for these threads. */

  /// Keeps track of the data for the load-balancing. \ingroup makegrid
  class balance_queue_data {
  public:
    typedef mcrx::nbody_data_grid::T_cell T_cell;
    
    T_cell* cell;///< Pointer to the cell
    /// List of particles that are in this cell 
    boost::shared_ptr<std::vector<T_particle*> > particle_list_;
    int level; ///< Current refinement level
    /// Default constructor: null cell, empty particle list, level 0.
    balance_queue_data (): cell (0), level (0) {};
    balance_queue_data (T_cell*c, boost::shared_ptr<std::vector<T_particle*> > pl,
			int l): cell (c), particle_list_ (pl), level (l) {};
  };
  /// The vector of cells that have been balanced 
  std::vector<balance_queue_data> balance_queue;
  class balance_thread_start;
  void process_balance_queue_cell ();
  ///@}
  

  ///@{
  /// \name Functions for grid building and refinement
  
  // this function builds the grid using the particles in particle_list
  void build_grid (); 
  void create_balanced_queue ();
  refinement_accuracy_data recursive_refine (T_grid& g,
					     const std::vector<T_particle*>& pl,
					     int level, std::vector<int>&);
  refinement_accuracy_data recursive_refine_body (T_cell* c, 
						  const std::vector<T_particle*>& pl,
						  int level, std::vector<int>&);
  refinement_accuracy_data project_particles (T_cell& c,
			  const std::vector<T_particle*> & pl);
  void unrefine_if_possible (T_cell& c, refinement_accuracy_data& racc);
  ///@}
  
public:
  /** Constructs an unrefined grid spanning [mi, ma] with nn cells
      per dimension.  size_factor and max_level start at 0 so they
      are never read uninitialized; load_snapshot is expected to set
      them (TODO confirm in the implementation file). */
  nbody_data_grid (const vec3d & mi, const vec3d & ma, const vec3i & nn):
    adaptive_grid<nbody_data_cell> (mi, ma, nn),
    size_factor (0), max_level (0),
    ctr (1), n_threads (0), work_chunk_levels (0) {};
  /** Construct a nbody_data_grid reading the structure from the
      specified GRIDSTRUCTURE HDU, while also loading the length
      unit. */
  nbody_data_grid (CCfits::ExtHDU& input, 
		   const vec3d& translate_origin=vec3d(0,0,0));

  bool load_snapshot (const mcrx::Snapshot&,
		      const int ml, const T_float size_fudge,
		      const tolerance_checker&,
		      CCfits::HDU*info = 0, bool use_SED = true);
  void save_data (CCfits::FITS&, const std::string&, const std::string&) const;
  /** Loads grid cell data from a FITS GRIDDATA HDU into an existing
      grid structure.  */
  void load_data (CCfits::ExtHDU&);
  void use_threads (int n) const {n_threads = n;};
  void set_work_chunk (int n) {work_chunk_levels = n;};

  /** Saves the grid structure to FITS file, using this grid's
      length_unit.  NOTE(review): declaring a save_structure here hides
      all base-class save_structure overloads for unqualified calls on
      nbody_data_grid; it would only override if T_base declared a
      virtual two-argument save_structure, which is not visible here. */
  void save_structure (CCfits::FITS& file, const std::string& hdu) const {
    T_base::save_structure (file, hdu, length_unit);};
};

/** Predicate used to copy particles that overlap with a grid cell,
    while keeping track of the minimum size of the particles and the
    maximum estimated metal density among them. 
    \ingroup makegrid */
class mcrx::check_particle_overlap_p :
  public std::unary_function<mcrx::nbody_data_grid::T_particle*, bool> {
private:
  // Must be qualified: an unqualified friend declaration inside a
  // class of namespace mcrx would declare and befriend mcrx::main,
  // not the program's global ::main.
  friend int ::main(int, char**);
  typedef mcrx::nbody_data_grid::T_float T_float;
  typedef mcrx::nbody_data_grid::T_particle T_particle;
  typedef mcrx::nbody_data_grid::T_cell T_cell;
  
  const T_cell& c;
  T_float gas_metallicity_;
  // Members that return data MUST be shared_ptrs because the object
  // is copied into the algorithm.
  boost::shared_ptr<T_float> min_size_;
  boost::shared_ptr<T_float> max_rho_;

public:
  /** Creates a predicate for cell cc.  gm is the gas metallicity used
      to add metals from the gas mass to the density estimate.  The
      minimum size starts at blitz::huge and the maximum density at 0. */
  check_particle_overlap_p (const T_cell& cc, T_float gm=0) :
    c (cc), gas_metallicity_ (gm), 
    min_size_ (new T_float (blitz::huge (T_float () ))),
    max_rho_(new T_float()) {};

  /// Returns the size of the smallest particles seen so far.
  T_float min_size () const {return *min_size_;};
  /// Returns the largest (m_metals + m_g*gas_metallicity)/radius^3 seen so far.
  T_float max_rho () const {return *max_rho_;};
  /** Returns true if the particle overlaps with the cell.  Also keeps
      track of the minimum size and the maximum density of the
      overlapping particles. */
  bool operator () (const T_particle*const p) {
    const bool o =p->overlap(c);
    if (o) {
      const T_float rho_est = (p->data().m_metals() + 
			       p->data().m_g()*gas_metallicity_)/
	(p->radius()*p->radius()*p->radius());
      *max_rho_ = std::max(*max_rho_, rho_est );
      *min_size_ = std::min (*min_size_, p->radius());
    }
    return o;
  };
};

/** Thread entry functor for the refinement workers.  Invoking it runs
    pop_cell_and_refine() on the grid passed to the constructor; the
    grid pointer is not owned and must outlive the thread.
    \ingroup makegrid */
class mcrx::nbody_data_grid::thread_start {
public:
  thread_start (nbody_data_grid* g) : grid_ (g) {}

  /// Thread body: delegate all work to the grid.
  void operator () () { grid_->pop_cell_and_refine (); }

private:
  nbody_data_grid* grid_; ///< Grid the worker operates on (non-owning).
};

/** Thread entry functor for the load-balancing workers.  Invoking it
    runs process_balance_queue_cell() on the grid passed to the
    constructor; the grid pointer is not owned and must outlive the
    thread.  \ingroup makegrid */
class mcrx::nbody_data_grid::balance_thread_start {
public:
  balance_thread_start (nbody_data_grid* g) : grid_ (g) {}

  /// Thread body: delegate all work to the grid.
  void operator () () { grid_->process_balance_queue_cell (); }

private:
  nbody_data_grid* grid_; ///< Grid the worker operates on (non-owning).
};


#endif

